# -*- coding: utf-8 -*-
# -------------------------------
# @文件：connect_mysql.py
# @时间：2024/4/9 下午2:37
# @作者：caiweichao
# @功能描述：链接 mysql 数据库工具
# -------------------------------
import pymysql
from pymysql import converters
from pymysql.constants import FIELD_TYPE

from util.basic.log import Log
from util.basic.analysis_yaml import AnalysisYaml


class ConnectMysql:
    def __init__(self, mysql_name):
        conv = converters.conversions
        conv[FIELD_TYPE.NEWDECIMAL] = float
        conv[FIELD_TYPE.DATE] = str
        conv[FIELD_TYPE.TIMESTAMP] = str
        conv[FIELD_TYPE.DATETIME] = str
        conv[FIELD_TYPE.TIME] = str
        conv[FIELD_TYPE.VARCHAR] = str
        try:
            mysql_config = AnalysisYaml().get_mysql_config(mysql_name=mysql_name)
            self.db = pymysql.Connect(
                host=mysql_config["HOST"],
                user=mysql_config["USER"],
                password=str(mysql_config["PASSWORD"]),
                database=None,
                port=int(mysql_config["PORT"]),
                conv=conv)
        except TimeoutError as e:
            Log.error(f'数据库链接超时请检查：{e}')
            raise TimeoutError(f'数据库链接超时请检\n{e}')
        except IndentationError as e:
            Log.error('数据库链接用户名不存在请检查')
            raise IndentationError(f"数据库链接用户名不存在请检查\n{e}")
        except pymysql.err.OperationalError as e:
            Log.error(f'用户名或密码错误请检查\n{e}')
            raise pymysql.err.OperationalError('用户名或密码错误请检查')
        self.cursor = self.db.cursor()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cursor.close()
        self.db.close()

    @staticmethod
    def check_sql(func):
        def warpper(*args, **kwargs):
            try:
                res = func(*args, **kwargs)
                return res
            except pymysql.err.ProgrammingError as e:
                Log.error(f"请检查sql是否正确 sql={args}")
                raise e

        return warpper

    # 查询单条数据并且返回 可以通过sql查询指定的值 也可以通过索引去选择指定的值
    @check_sql
    def fetch_one(self, sql, name=None):
        # 按照sql进行查询
        self.cursor.execute(sql)
        sql_data = self.cursor.fetchone()
        return sql_data

    @check_sql
    def fetch_all(self, sql):  # 查询多条数据并且返回
        # 按照sql进行查询
        self.cursor.execute(sql)
        sql_data = self.cursor.fetchall()
        return sql_data

    @check_sql
    def insert_or_update_data(self, sql):
        cursor = self.db.cursor()
        try:
            # 执行sql
            cursor.execute(sql)
            # 提交到数据库执行
            self.db.commit()
        except pymysql.err.ProgrammingError as e:
            Log.error(f"请检查sql是否正确 sql={sql}")
            raise e
